Diagonal State Space Models

Diagonal State Spaces are as Effective as Structured State Spaces

Ankit Gupta, Albert Gu, Jonathan Berant.

On the Parameterization and Initialization of Diagonal State Space Models

Albert Gu, Ankit Gupta, Karan Goel, Christopher Ré.


Note: This page is meant as a standalone complement to Section 2 [TODO Link] of the original blog post.

The months following the release of the S4 paper by Gu et al. were marked by a wave of excitement around the new model, its ability to handle extremely long sequences, and, more generally, what such a departure from Transformer-based architectures could mean. The original authors came out with a follow-up paper applying S4 to audio generation, and weeks later a completely different group applied S4 to long-range movie clip classification.

Yet S4 relies on an intricate algorithm, with a complicated implementation, for diagonal plus low-rank (DPLR) state space models (SSMs). To motivate this representation, the S4 paper considered the case of diagonal state matrices and outlined an extremely simple algorithm that can be implemented in just a few lines. However, this was not used, because no diagonal SSMs were known that could mathematically model long-range dependencies, which was S4’s ultimate goal. Instead, S4 used a class of special matrices that could not be diagonalized, but found a way to transform them into almost diagonal form, leading to the more general DPLR representation.
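To give a feel for how little code the diagonal case needs, here is a minimal sketch, not the implementation developed later in this post. It assumes a zero-order-hold discretization and complex parameters of shape (N,); the function names `diagonal_ssm_kernel` and `diagonal_ssm_scan` are hypothetical and chosen only for illustration. With a diagonal state matrix the discretization is elementwise, and the convolution kernel reduces to a Vandermonde matrix-vector product.

```python
import jax
import jax.numpy as np  # same alias as the imports used in this post


def diagonal_ssm_kernel(Lambda, B, C, step, L):
    """Length-L convolution kernel of a diagonal SSM (illustrative sketch).

    Lambda, B, C: complex arrays of shape (N,), with Lambda the diagonal of
    the state matrix. Assumes zero-order-hold discretization with step size
    `step` and nonzero entries in Lambda.
    """
    # Because the state matrix is diagonal, discretization is elementwise:
    #   Abar = exp(step * Lambda),  Bbar = (Abar - 1) / Lambda * B
    Abar = np.exp(step * Lambda)
    Bbar = (Abar - 1.0) / Lambda * B
    # K[l] = sum_n C[n] * Abar[n]**l * Bbar[n], written as a Vandermonde product.
    vandermonde = np.exp(step * Lambda[:, None] * np.arange(L))  # shape (N, L)
    return ((C * Bbar) @ vandermonde).real


def diagonal_ssm_scan(Lambda, B, C, step, u):
    """Reference: run the same discretized SSM as a recurrence over input u."""
    Abar = np.exp(step * Lambda)
    Bbar = (Abar - 1.0) / Lambda * B

    def f(x, u_k):
        x = Abar * x + Bbar * u_k      # elementwise state update
        return x, (C * x).sum().real   # real part of the output

    _, y = jax.lax.scan(f, np.zeros_like(Lambda), u)
    return y
```

Convolving an input with this kernel and truncating to the input length should agree with the recurrence up to floating-point error, which makes for a handy sanity check of either function.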

However, at the end of March 2022, an effective diagonal model (DSS) was introduced in [Diagonal State Spaces are as Effective as Structured State Spaces], based on approximating S4’s matrix. This important observation allows diagonal SSMs to be used while preserving the empirical strengths of S4! Diagonal SSMs were further fleshed out in [On the Parameterization and Initialization of Diagonal State Space Models] (S4D), which combined S4’s original diagonal algorithm with new theory explaining why this particular diagonal initialization can model long-range dependencies. The rest of this post steps through this much simpler model, an even more structured state space for diagonal matrices.

This post aims to be a complete, standalone complement to Section 2 of the original Annotated S4 post. We’ll keep using JAX with the Flax NN library for consistency with the original post; PyTorch versions of the DSS and S4D models are publicly available.

# import s4.s4 as s4  TODO -- For some reason breaks streamlit...
import s4
from functools import partial
import jax
import jax.numpy as np
from flax import linen as nn
from jax.nn.initializers import lecun_normal, normal
rng = jax.random.PRNGKey(1)

Table of Contents